{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "fZ_xQvU70UQc" }, "source": [ "# Tune-A-Video\n", "\n", "**[Tune-A-Video: One-Shot Tuning of Image Diffusion Models for Text-to-Video Generation](https://arxiv.org/abs/2212.11565)** \n", "[Jay Zhangjie Wu](https://zhangjiewu.github.io/), \n", "[Yixiao Ge](https://geyixiao.com/), \n", "[Xintao Wang](https://xinntao.github.io/), \n", "[Stan Weixian Lei](), \n", "[Yuchao Gu](https://ycgu.site/), \n", "[Wynne Hsu](https://www.comp.nus.edu.sg/~whsu/), \n", "[Ying Shan](https://scholar.google.com/citations?user=4oXBp9UAAAAJ&hl=en), \n", "[Xiaohu Qie](https://scholar.google.com/citations?user=mk-F69UAAAAJ&hl=en), \n", "[Mike Zheng Shou](https://sites.google.com/view/showlab) \n", "\n", "[![Project Website](https://img.shields.io/badge/Project-Website-orange)](https://tuneavideo.github.io/)\n", "[![arXiv](https://img.shields.io/badge/arXiv-2212.11565-b31b1b.svg)](https://arxiv.org/abs/2212.11565)\n", "[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/Tune-A-Video-library/Tune-A-Video-Training-UI)\n", "[![GitHub](https://img.shields.io/github/stars/showlab/Tune-A-Video?style=social)](https://github.com/showlab/Tune-A-Video)" ] }, { "cell_type": "markdown", "metadata": { "id": "wnTMyW41cC1E" }, "source": [ "## Setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XU7NuMAA2drw" }, "outputs": [], "source": [ "#@markdown Check type of GPU and VRAM available.\n", "!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "D1PRgre3Gt5U" }, "outputs": [], "source": [ "#@title Install requirements\n", "\n", "!git clone https://github.com/showlab/Tune-A-Video.git /content/Tune-A-Video\n", "%cd /content/Tune-A-Video \n", "# %pip install -r requirements.txt\n", "%pip install -q -U --pre triton\n", "%pip install -q diffusers[torch]==0.11.1 transformers==4.26.0 bitsandbytes==0.35.4 \\\n", "decord accelerate omegaconf einops ftfy gradio imageio-ffmpeg xformers" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "m6I6kZNG3Inb" }, "outputs": [], "source": [ "#@title Download pretrained model\n", "\n", "#@markdown Name/Path of the initial model.\n", "MODEL_NAME = \"CompVis/stable-diffusion-v1-4\" #@param {type:\"string\"}\n", "\n", "#@markdown If model should be download from a remote repo. 
{ "cell_type": "markdown", "metadata": { "id": "qn5ILIyDJIcX" }, "source": [ "## Usage\n" ] },
{ "cell_type": "markdown", "metadata": { "id": "REmFAHfz9Y_X" }, "source": [ "### Training\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Rxg0y5MBudmd" }, "outputs": [], "source": [ "#@markdown Whether model weights should be saved directly to Google Drive (takes around 4-5 GB).\n", "save_to_gdrive = False #@param {type:\"boolean\"}\n", "if save_to_gdrive:\n", "    from google.colab import drive\n", "    drive.mount('/content/drive')\n", "\n", "#@markdown Enter the directory name to save the model at.\n", "\n", "OUTPUT_DIR = \"outputs/man-skiing\" #@param {type:\"string\"}\n", "if save_to_gdrive:\n", "    OUTPUT_DIR = \"/content/drive/MyDrive/\" + OUTPUT_DIR\n", "\n", "print(f\"[*] Weights will be saved at {OUTPUT_DIR}\")\n", "\n", "!mkdir -p $OUTPUT_DIR\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "32gYIDDR1aCp" }, "outputs": [], "source": [ "#@markdown Upload your video by running this cell.\n", "\n", "#@markdown OR\n", "\n", "#@markdown You can use the file manager on the left panel to upload (drag and drop) to the `data` folder.\n", "\n", "import os\n", "import shutil\n", "from google.colab import files\n", "\n", "os.makedirs(\"data\", exist_ok=True)  # make sure the destination folder exists\n", "uploaded = files.upload()\n", "for filename in uploaded.keys():\n", "    dst_path = os.path.join(\"data\", filename)\n", "    shutil.move(filename, dst_path)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "wGGFFpNcR2d_" }, "outputs": [], "source": [ "#@markdown Train config\n", "\n", "from omegaconf import OmegaConf\n", "\n", "CONFIG_NAME = \"configs/man-skiing.yaml\" #@param {type:\"string\"}\n", "\n", "train_video_path = \"data/man-skiing.mp4\" #@param {type:\"string\"}\n", "train_prompt = \"a man is skiing\" #@param {type:\"string\"}\n", "video_length = 8 #@param {type:\"number\"}\n", "width = 512 #@param {type:\"number\"}\n", "height = 512 #@param {type:\"number\"}\n", "learning_rate = 3e-5 #@param {type:\"number\"}\n", "train_steps = 300 #@param {type:\"number\"}\n", "\n", "config = {\n", "    \"pretrained_model_path\": MODEL_NAME,\n", "    \"output_dir\": OUTPUT_DIR,\n", "    \"train_data\": {\n", "        \"video_path\": train_video_path,\n", "        \"prompt\": train_prompt,\n", "        \"n_sample_frames\": video_length,\n", "        \"width\": width,\n", "        \"height\": height,\n", "        \"sample_start_idx\": 0,\n", "        \"sample_frame_rate\": 2,\n", "    },\n", "    \"validation_data\": {\n", "        \"prompts\": [\n", "            \"mickey mouse is skiing on the snow\",\n", "            \"spider man is skiing on the beach, cartoon style\",\n", "            \"wonder woman, wearing a cowboy hat, is skiing\",\n", "            \"a man, wearing pink clothes, is skiing at sunset\",\n", "        ],\n", "        \"video_length\": video_length,\n", "        \"width\": width,\n", "        \"height\": height,\n", "        \"num_inference_steps\": 20,\n", "        \"guidance_scale\": 12.5,\n", "        \"use_inv_latent\": True,\n", "        \"num_inv_steps\": 50,\n", "    },\n", "    \"learning_rate\": learning_rate,\n", "    \"train_batch_size\": 1,\n", "    \"max_train_steps\": train_steps,\n", "    \"checkpointing_steps\": 1000,\n", "    \"validation_steps\": 100,\n", "    \"trainable_modules\": [\n", "        \"attn1.to_q\",\n", "        \"attn2.to_q\",\n", "        \"attn_temp\",\n", "    ],\n", "    \"seed\": 33,\n", "    \"mixed_precision\": \"fp16\",\n", "    \"use_8bit_adam\": False,\n", "    \"gradient_checkpointing\": True,\n", "    \"enable_xformers_memory_efficient_attention\": True,\n", "}\n", "\n", "OmegaConf.save(config, CONFIG_NAME)" ] },
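{ "cell_type": "code", "execution_count": null, "metadata": { "id": "frameCheck001" }, "outputs": [], "source": [ "#@markdown (Optional) Preview which frames training will consume. A minimal sketch, not part of the original notebook: it reads the clip with `decord` (installed above) and assumes the repo's dataset samples frames at `sample_start_idx + k * sample_frame_rate` for `k = 0 .. n_sample_frames - 1`.\n", "from decord import VideoReader\n", "\n", "vr = VideoReader(train_video_path)\n", "print(f\"video: {len(vr)} frames, ~{vr.get_avg_fps():.1f} fps, frame shape {vr[0].shape}\")\n", "\n", "# Indices the config above would select, under the assumed sampling rule.\n", "start = config[\"train_data\"][\"sample_start_idx\"]\n", "rate = config[\"train_data\"][\"sample_frame_rate\"]\n", "indices = [start + k * rate for k in range(video_length)]\n", "print(f\"sampled frame indices: {indices}\")\n", "assert indices[-1] < len(vr), \"video too short for the requested n_sample_frames / sample_frame_rate\"" ] },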
\"attn2.to_q\",\n", " \"attn_temp\",\n", " ],\n", " \"seed\": 33,\n", " \"mixed_precision\": \"fp16\",\n", " \"use_8bit_adam\": False,\n", " \"gradient_checkpointing\": True,\n", " \"enable_xformers_memory_efficient_attention\": True,\n", "}\n", "\n", "OmegaConf.save(config, CONFIG_NAME)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jjcSXTp-u-Eg" }, "outputs": [], "source": [ "!accelerate launch train_tuneavideo.py --config=$CONFIG_NAME" ] }, { "cell_type": "markdown", "metadata": { "id": "ToNG4fd_dTbF" }, "source": [ "### Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "91bsSFv2Punm" }, "outputs": [], "source": [ "import torch\n", "from torch import autocast\n", "from diffusers import DDIMScheduler\n", "from tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline\n", "from tuneavideo.models.unet import UNet3DConditionModel\n", "from tuneavideo.util import save_videos_grid\n", "\n", "\n", "unet = UNet3DConditionModel.from_pretrained(OUTPUT_DIR, subfolder='unet', torch_dtype=torch.float16).to('cuda')\n", "scheduler = DDIMScheduler.from_pretrained(MODEL_NAME, subfolder='scheduler')\n", "pipe = TuneAVideoPipeline.from_pretrained(MODEL_NAME, unet=unet, scheduler=scheduler, torch_dtype=torch.float16).to(\"cuda\")\n", "pipe.enable_xformers_memory_efficient_attention()\n", "pipe.enable_vae_slicing()\n", "\n", "g_cuda = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "cellView": "form", "id": "oIzkltjpVO_f" }, "outputs": [], "source": [ "#@markdown Can set random seed here for reproducibility.\n", "g_cuda = torch.Generator(device='cuda')\n", "seed = 1234 #@param {type:\"number\"}\n", "g_cuda.manual_seed(seed)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "K6xoHWSsbcS3", "scrolled": false }, "outputs": [], "source": [ "#@markdown Run for generating videos.\n", "\n", "prompt = \"iron man is skiing\" #@param {type:\"string\"}\n", "negative_prompt = \"\" #@param {type:\"string\"}\n", "use_inv_latent = True #@param {type:\"boolean\"}\n", "inv_latent_path = \"\" #@param {type:\"string\"}\n", "num_samples = 1 #@param {type:\"number\"}\n", "guidance_scale = 12.5 #@param {type:\"number\"}\n", "num_inference_steps = 50 #@param {type:\"number\"}\n", "video_length = 8 #@param {type:\"number\"}\n", "height = 512 #@param {type:\"number\"}\n", "width = 512 #@param {type:\"number\"}\n", "\n", "ddim_inv_latent = None\n", "if use_inv_latent and inv_latent_path == \"\":\n", " from natsort import natsorted\n", " from glob import glob\n", " import os\n", " inv_latent_path = natsorted(glob(f\"{OUTPUT_DIR}/inv_latents/*\"))[-1]\n", " ddim_inv_latent = torch.load(inv_latent_path).to(torch.float16)\n", " print(f\"DDIM inversion latent loaded from {inv_latent_path}\")\n", "\n", "with autocast(\"cuda\"), torch.inference_mode():\n", " videos = pipe(\n", " prompt, \n", " latents=ddim_inv_latent,\n", " video_length=video_length, \n", " height=height, \n", " width=width, \n", " negative_prompt=negative_prompt,\n", " num_videos_per_prompt=num_samples,\n", " num_inference_steps=num_inference_steps, \n", " guidance_scale=guidance_scale,\n", " generator=g_cuda\n", " ).videos\n", "\n", "save_dir = \"./results\" #@param {type:\"string\"}\n", "save_path = f\"{save_dir}/{prompt}.gif\"\n", "save_videos_grid(videos, save_path)\n", "\n", "# display\n", "from IPython.display import Image, display\n", "display(Image(filename=save_path))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": 
"jXgi8HM4c-DA" }, "outputs": [], "source": [ "#@markdown Free runtime memory\n", "exit()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13-final" }, "vscode": { "interpreter": { "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" } } }, "nbformat": 4, "nbformat_minor": 0 }